import argparse
import os
from pathlib import Path
import wandb
import yaml
import numpy as np
import torch
from cemcd.training import train_cem, train_hicem
from experiment_utils import load_config, load_datasets, get_intervention_accuracies

# NATO phonetic alphabet, A to Z. create_run_name combines three of these
# words to build a run name.
ALPHABET = """
    ALPHA BRAVO CHARLIE DELTA ECHO FOXTROT GOLF HOTEL INDIA
    JULIET KILO LIMA MIKE NOVEMBER OSCAR PAPA QUEBEC ROMEO
    SIERRA TANGO UNIFORM VICTOR WHISKEY XRAY YANKEE ZULU
""".split()

def create_run_name(results_dir, dataset):
    """Return the first unused run name of the form ``<dataset>-<W1>-<W2>-<W3>``.

    The words are taken from ALPHABET in lexicographic order. A name counts
    as used when a path with that name already exists under ``results_dir``.

    Raises:
        RuntimeError: if every combination of three words is already taken.
    """
    base_dir = Path(results_dir)
    candidates = (
        f"{dataset}-{first}-{second}-{third}"
        for first in ALPHABET
        for second in ALPHABET
        for third in ALPHABET
    )
    for candidate in candidates:
        if not (base_dir / candidate).exists():
            return candidate
    raise RuntimeError("All run names have been used.")

def get_accuracies(test_results, n_provided_concepts, model_name):
    """Summarise per-concept test metrics into provided vs. discovered groups.

    Keys of ``test_results`` that start with "concept" are split on "_" and
    their second field is parsed as an integer ``n``. ``n <= n_provided_concepts``
    puts the metric in the provided group; larger ``n`` puts it in the
    discovered group. Keys ending in "auc" are AUCs; every other concept key is
    treated as an accuracy.

    Args:
        test_results: mapping of metric name to value; must contain
            "test_y_accuracy".
        n_provided_concepts: threshold separating provided from discovered
            concept indices (see above).
        model_name: prefix for every key of the returned dict.

    Returns:
        dict of floats and lists of floats, all rounded to 4 decimals. The
        ``<model_name>_discovered_concept_*`` keys are present only when at
        least one discovered-concept accuracy was found. The mean of an empty
        group is NaN.
    """
    def _mean(values):
        # An explicit NaN for an empty group gives the same value np.mean would
        # produce, without NumPy's "Mean of empty slice" RuntimeWarning.
        if not values:
            return float("nan")
        return float(round(np.mean(values), 4))

    def _rounded(values):
        return [round(float(v), 4) for v in values]

    # Read the task accuracy first so a missing key fails before any parsing.
    task_accuracy = float(round(test_results["test_y_accuracy"], 4))

    provided_concept_accuracies = []
    provided_concept_aucs = []
    discovered_concept_accuracies = []
    discovered_concept_aucs = []
    for key, value in test_results.items():
        if not key.startswith("concept"):
            continue
        is_provided = int(key.split("_")[1]) <= n_provided_concepts
        if key.endswith("auc"):
            bucket = provided_concept_aucs if is_provided else discovered_concept_aucs
        else:
            bucket = provided_concept_accuracies if is_provided else discovered_concept_accuracies
        bucket.append(value)

    results = {
        f"{model_name}_task_accuracy": task_accuracy,
        f"{model_name}_provided_concept_accuracy": _mean(provided_concept_accuracies),
        f"{model_name}_provided_concept_accuracies": _rounded(provided_concept_accuracies),
        f"{model_name}_provided_concept_auc": _mean(provided_concept_aucs),
        f"{model_name}_provided_concept_aucs": _rounded(provided_concept_aucs),
    }
    if discovered_concept_accuracies:
        results.update({
            f"{model_name}_discovered_concept_accuracy": _mean(discovered_concept_accuracies),
            f"{model_name}_discovered_concept_accuracies": _rounded(discovered_concept_accuracies),
            f"{model_name}_discovered_concept_auc": _mean(discovered_concept_aucs),
            f"{model_name}_discovered_concept_aucs": _rounded(discovered_concept_aucs),
        })
    return results

def parse_arguments():
    """Parse this script's command-line options.

    Both options are required strings.

    Returns:
        argparse.Namespace with attributes ``config`` and
        ``discovered_concept_file``.
    """
    parser = argparse.ArgumentParser()
    option_specs = (
        (("-c", "--config"), "Path to the experiment config file."),
        (("-d", "--discovered-concept-file"), "Path to a YAML file listing discovered concept labels."),
    )
    for flags, help_text in option_specs:
        parser.add_argument(*flags, type=str, required=True, help=help_text)
    return parser.parse_args()

def get_provided_intervention_accuracies(models, datasets, model_name_prefix, model_type="hicem"):
    """Evaluate each model under interventions on its dataset's provided concepts.

    For every (dataset, model) pair, ``get_intervention_accuracies`` is called
    twice over concept indices ``range(dataset.n_concepts)``: once with
    ``one_at_a_time=False`` and once with ``one_at_a_time=True``. The extra
    label columns for the model's remaining concepts
    (``model.n_concepts - dataset.n_concepts``) are filled with NaN in the test
    loader. This presumably marks them as unknown at test time.

    Args:
        models: trained models, paired positionally with ``datasets``.
        datasets: dataset wrappers exposing ``n_concepts``, ``foundation_model``
            and ``test_dl``.
        model_name_prefix: first component of each result key.
        model_type: last component of the model name (default "hicem").

    Returns:
        dict mapping
        ``<prefix>_<foundation_model>_<model_type>_provided_concept_interventions_cumulative``
        and ``..._one_at_a_time`` to the values returned by
        ``get_intervention_accuracies``. Returns an empty dict when there are
        no models.
    """
    results = {}
    for dataset, model in zip(datasets, models):
        # Sizes are taken from this pair rather than from datasets[0]/models[0],
        # so datasets with different test-set sizes or concept counts are handled.
        n_provided_concepts = dataset.n_concepts
        n_discovered_concepts = model.n_concepts - n_provided_concepts
        test_dataset_size = len(dataset.test_dl().dataset)
        provided_concepts = range(n_provided_concepts)

        model_name = f"{model_name_prefix}_{dataset.foundation_model}_{model_type}"
        for suffix, one_at_a_time in (("cumulative", False), ("one_at_a_time", True)):
            results[f"{model_name}_provided_concept_interventions_{suffix}"] = get_intervention_accuracies(
                model=model,
                test_dl=dataset.test_dl(np.full((test_dataset_size, n_discovered_concepts), np.nan)),
                concepts_to_intervene=provided_concepts,
                one_at_a_time=one_at_a_time)

    return results

def log_results(config, run_dir, results):
    """Append ``results`` as YAML to ``run_dir/results.yaml``.

    The same dict is also sent to ``wandb.log`` when ``config["use_wandb"]``
    is truthy.
    """
    results_path = Path(run_dir).joinpath("results.yaml")
    with results_path.open(mode="a") as handle:
        yaml.safe_dump(results, handle)

    if not config["use_wandb"]:
        return
    wandb.log(results)


def _load_discovered_concept_labels(discovered_concepts_file, n_rows):
    """Load the label arrays listed in a YAML file and stack them column-wise.

    Args:
        discovered_concepts_file: path to a YAML file whose top-level value is
            an iterable of paths, each loadable with ``np.load`` into a 2-D
            array with ``n_rows`` rows.
        n_rows: required row count for every array (the training-set size).

    Returns:
        ``(labels, n_sub_concepts)``. ``labels`` is a float array of shape
        ``(n_rows, sum(n_sub_concepts))`` holding the arrays side by side in
        file order. ``n_sub_concepts`` lists each array's column count.

    Raises:
        ValueError: if the YAML file is empty, or an array is not 2-D with
            ``n_rows`` rows.
    """
    with open(discovered_concepts_file, "r") as f:
        label_files = yaml.safe_load(f)
    if label_files is None:
        raise ValueError(
            f"{discovered_concepts_file} is empty; expected a list of label files.")

    label_arrays = []
    for file in label_files:
        labels = np.load(file)
        # Without this check a (1, k) array would be silently broadcast into
        # every training row by the slice assignment below.
        if labels.ndim != 2 or labels.shape[0] != n_rows:
            raise ValueError(
                f"{file} has shape {labels.shape}; expected a 2-D array with {n_rows} rows.")
        label_arrays.append(labels)
    n_sub_concepts = [labels.shape[1] for labels in label_arrays]

    # Preallocating with NaN keeps a float dtype and yields an (n_rows, 0)
    # array when no files are listed.
    combined = np.full((n_rows, sum(n_sub_concepts)), np.nan)
    start = 0
    for labels, n in zip(label_arrays, n_sub_concepts):
        combined[:, start:start + n] = labels
        start += n
    return combined, n_sub_concepts


def evaluate(run_dir, config, discovered_concepts_file):
    """Train HiCEMs on provided plus discovered concepts and log their metrics.

    For every dataset returned by ``load_datasets(config)``, a HiCEM is trained
    with the discovered-concept labels appended to the training loader. The
    validation and test loaders receive NaN for every discovered-concept
    column. Test accuracies (via ``get_accuracies``) and provided-concept
    intervention accuracies are passed to ``log_results``.

    Args:
        run_dir: output directory for this run. Checkpoint paths and the
            results file are placed under it.
        config: experiment configuration mapping.
        discovered_concepts_file: YAML file listing label arrays, each shaped
            (train-set size, number of sub-concepts). Presumably there is one
            file per discovered concept.
    """
    run_dir = Path(run_dir)

    datasets = load_datasets(config)

    train_dataset_size = len(datasets[0].train_dl().dataset)
    val_dataset_size = len(datasets[0].val_dl().dataset)
    test_dataset_size = len(datasets[0].test_dl().dataset)

    def log(results):
        log_results(config, run_dir, results)

    discovered_concept_labels, n_discovered_sub_concepts = _load_discovered_concept_labels(
        discovered_concepts_file, train_dataset_size)
    n_discovered_total = sum(n_discovered_sub_concepts)

    log({"n_discovered_sub_concepts": n_discovered_sub_concepts})

    sub_concepts = [(n, 0) for n in n_discovered_sub_concepts]  # We don't split negative embeddings
    models_with_discovered_concepts = []
    for dataset in datasets:
        model, test_results = train_hicem(
            sub_concepts=sub_concepts,
            n_tasks=dataset.n_tasks,
            latent_representation_size=dataset.latent_representation_size,
            embedding_size=config["hicem_embedding_size"],
            concept_loss_weight=config["hicem_concept_loss_weight"],
            train_dl=dataset.train_dl(discovered_concept_labels),
            val_dl=dataset.val_dl(np.full((val_dataset_size, n_discovered_total), np.nan)),
            test_dl=dataset.test_dl(np.full((test_dataset_size, n_discovered_total), np.nan)),
            save_path=run_dir / f"enhanced_{dataset.foundation_model}_hicem.pth",
            max_epochs=config["max_epochs"],
            use_task_class_weights=config["use_task_class_weights"],
            use_concept_loss_weights=config["use_concept_loss_weights"])
        log(get_accuracies(test_results, dataset.n_concepts, f"enhanced_{dataset.foundation_model}_hicem"))
        models_with_discovered_concepts.append(model)

    # NOTE(review): disabled CEM baseline, kept for reference; uncomment to re-enable.
    # cems_with_discovered_concepts = []
    # for dataset in datasets:
    #     model, test_results = train_cem(
    #         n_concepts=dataset.n_concepts + sum(n_discovered_sub_concepts),
    #         n_tasks=dataset.n_tasks,
    #         latent_representation_size=dataset.latent_representation_size,
    #         embedding_size=config["cem_embedding_size"],
    #         concept_loss_weight=config["cem_concept_loss_weight"],
    #         train_dl=dataset.train_dl(discovered_concept_labels),
    #         val_dl=dataset.val_dl(np.full((val_dataset_size, sum(n_discovered_sub_concepts)), np.nan)),
    #         test_dl=dataset.test_dl(np.full((test_dataset_size, sum(n_discovered_sub_concepts)), np.nan)),
    #         save_path=run_dir / f"enhanced_{dataset.foundation_model}_cem.pth",
    #         max_epochs=config["max_epochs"],
    #         use_task_class_weights=config["use_task_class_weights"],
    #         use_concept_loss_weights=config["use_concept_loss_weights"])
    #     log(get_accuracies(test_results, dataset.n_concepts, f"enhanced_{dataset.foundation_model}_cem"))
    #     cems_with_discovered_concepts.append(model)

    log(get_provided_intervention_accuracies(
            models=models_with_discovered_concepts,
            datasets=datasets,
            model_name_prefix="enhanced"
    ))

    # log(get_provided_intervention_accuracies(
    #     models=cems_with_discovered_concepts,
    #     datasets=datasets,
    #     model_name_prefix="enhanced",
    #     model_type="cem"
    # ))

if __name__ == "__main__":
    # Script entry point: pick a fresh run directory, snapshot the config into
    # it, optionally start a wandb run, then train and evaluate.
    torch.set_float32_matmul_precision("high")
    args = parse_arguments()

    config = load_config(args.config)
    run_name = create_run_name(config["results_dir"], config["dataset"])
    print(f"RUN NAME: {run_name}\n")
    run_dir = Path(config["results_dir"]) / run_name
    # parents=True lets the very first run create results_dir. exist_ok is
    # left False so a run-name collision still fails loudly instead of
    # overwriting.
    run_dir.mkdir(parents=True)
    (run_dir / "config.yaml").write_text(Path(args.config).read_text())
    if config["use_wandb"]:
        wandb.init(
            project="cem-concept-discovery-imagenet",
            config=config,
            name=run_name,
            notes=config["description"])
    evaluate(
        run_dir=run_dir,
        config=config,
        discovered_concepts_file=args.discovered_concept_file)

    if config["use_wandb"]:
        # Upload every file written to the run directory, then close the run.
        wandb.save(os.path.join(run_dir, "*"))
        wandb.finish()
